## ---- Setup -----------------------------------------------------------------
## Send everything printed during this run to a log file, then pull in the
## shared simulation setup and the saved naturalistic population data.
## NOTE(review): the objects used below (pop, covariates, nsim, tpate,
## expt_include_*, y0_sd, y1_sd) presumably come from these two files -- confirm.
sink("./generated/appendix_2/natural_nonlinear_correct.txt")
source("./04_simulation_00_setup.R")
load("./generated/appendix_2/naturalistic_sim_data.RData")

start_time <- Sys.time()

## Global seed. Each replicate in the loop below re-seeds itself with
## set.seed(i + 10000).
set.seed(6725498)

samp_sizes <- c(500, 1000, 4000)  # experimental sample sizes to simulate
nboot <- 1                         # forwarded to tpate() as `sims`

full_sims_nat <- NULL              # results accumulated across all settings

## ---- Simulation ------------------------------------------------------------

## Outcome model formula: Y on treatment plus every covariate. It does not
## depend on any loop variable, so it is built once.
formula_outcome <- as.formula(
  paste0("Y ~ treat_ind + ", paste0(covariates, collapse = " + "))
)

## Worker count for mclapply(): leave 3 cores free, but never go below 1.
## detectCores() - 3 is <= 0 on small machines and detectCores() may return
## NA; mclapply() rejects both.
n_cores <- max(1L, detectCores() - 3L, na.rm = TRUE)

## Turn one tpate() fit into a one-row result data frame.
##   estimator, type: labels written into the row.
##   fit:       a tpate() result or an "error" condition object. R evaluates
##              arguments lazily, so a tpate() call written directly in this
##              argument only runs inside the tryCatch() below.
##   component: element of the fit holding est/se/ci_lower/ci_upper
##              ("tpate" or "sate").
## Any error produces a labelled row with est = NA for that estimator instead
## of aborting the whole replicate.
extract_estimate <- function(estimator, type, fit, component = "tpate") {
  tryCatch(
    {
      if (inherits(fit, "error")) {
        stop(fit)
      }
      data.frame(
        estimator = estimator, type = type,
        fit[[component]][c("est", "se", "ci_lower", "ci_upper")] %>% data.frame()
      )
    },
    error = function(e) {
      data.frame(estimator = estimator, type = type, est = NA_real_)
    }
  )
}

for (samp_size in samp_sizes) {
  for (correct in "bart") {
    ## "bart" uses expt_include_bart and the *_bart potential outcomes;
    ## "ols" uses expt_include_linear and the *_linear potential outcomes.
    stopifnot(correct %in% c("bart", "ols"))

    for (correct_sample in c("Yes")) {
      print(correct)
      print(samp_size)

      if (correct_sample == "Yes") {
        formula_weights <- as.formula(paste0("S ~ ", paste0(covariates, collapse = " + ")))
      } else {
        ## Only religious and PID
        formula_weights <- as.formula(paste0("S ~ ", paste0(covariates[-c(1, 2, 3, 4, 9, 10, 11, 12)], collapse = " + ")))
      }

      sims <- mclapply(seq_len(nsim), function(i) {
        ## set.seed within sim so match across files
        set.seed(i + 10000)

        expt_include <- if (correct == "bart") expt_include_bart else expt_include_linear

        ## draw a sample from the population
        exp_data <- sample_n(pop, size = samp_size, replace = TRUE, weight = expt_include)
        n_exp <- nrow(exp_data)
        ## Complete randomization: ceiling(n/2) controls and floor(n/2)
        ## treated units, in random order.
        exp_data$treat_ind <- sample(
          rep(c(0, 1), each = ceiling(n_exp / 2))[seq_len(n_exp)],
          replace = FALSE
        )

        ## make a small pop for estimation
        pop_subset <- pop %>% sample_n(5000)

        ## Add noise to the potential outcomes, reveal the observed Y, and
        ## record the true population / finite-subset PATE.
        if (correct == "bart") {
          exp_data <- exp_data %>%
            mutate(Y0_bart = Y0_bart + rnorm(n(), 0, y0_sd),
                   Y1_bart = Y1_bart + rnorm(n(), 0, y1_sd),
                   Y = ifelse(treat_ind, Y1_bart, Y0_bart))

          PATE <- data.frame(estimator = "PATE", type = "PATE", est = mean(pop$Y1_bart - pop$Y0_bart))
          fs_PATE <- data.frame(estimator = "fs_PATE", type = "PATE", est = mean(pop_subset$Y1_bart - pop_subset$Y0_bart))
        } else {
          exp_data <- exp_data %>%
            mutate(Y0_linear = Y0_linear + rnorm(n(), 0, y0_sd),
                   Y1_linear = Y1_linear + rnorm(n(), 0, y1_sd),
                   Y = ifelse(treat_ind, Y1_linear, Y0_linear))

          PATE <- data.frame(estimator = "PATE", type = "PATE", est = mean(pop$Y1_linear - pop$Y0_linear))
          fs_PATE <- data.frame(estimator = "fs_PATE", type = "PATE", est = mean(pop_subset$Y1_linear - pop_subset$Y0_linear))
        }

        ## SATE and IPW-logit share a single tpate() fit. Keep its error (if
        ## any) so that both rows become NA rather than losing the replicate.
        ipw_est <- tryCatch(
          tpate(formula_outcome = formula_outcome,
                formula_weights = formula_weights,
                exp_data = exp_data, pop_data = pop_subset,
                treatment = "treat_ind",
                est_type = "ipw", weights_type = "logit",
                sims = nboot,
                compute_sate = TRUE,
                boot = FALSE,
                numCores = 1),
          error = function(e) e
        )

        sate_est <- extract_estimate("SATE", "SATE", ipw_est, component = "sate")

        # 1. Weighting-based Estimator
        ipw_logit <- extract_estimate("IPW-logit", "weighting", ipw_est)

        wls_logit <- extract_estimate(
          "wLS-logit", "weighting",
          tpate(formula_outcome = formula_outcome,
                formula_weights = formula_weights,
                exp_data = exp_data, pop_data = pop_subset,
                treatment = "treat_ind",
                est_type = "wls", weights_type = "logit",
                sims = nboot,
                compute_sate = FALSE,
                boot = FALSE,
                numCores = 1)
        )

        # 2. Outcome-based Estimator
        out_ols <- extract_estimate(
          "OLS-proj", "outcome",
          tpate(formula_outcome = formula_outcome,
                formula_weights = formula_weights,
                exp_data = exp_data, pop_data = pop_subset,
                treatment = "treat_ind",
                est_type = "outcome-ols",
                sims = nboot,
                compute_sate = FALSE,
                boot = TRUE,
                numCores = 1)
        )

        ## NOTE(review): unlike the other calls, the two BART fits do not pass
        ## numCores = 1. They run inside mclapply workers, so any internal
        ## parallelism would oversubscribe cores. Left unchanged in case it
        ## affects RNG streams matched across files -- confirm.
        out_bart <- extract_estimate(
          "BART-proj", "outcome",
          tpate(formula_outcome = formula_outcome,
                formula_weights = formula_weights,
                exp_data = exp_data, pop_data = pop_subset,
                treatment = "treat_ind",
                est_type = "outcome-bart",
                sims = nboot,
                compute_sate = FALSE,
                boot = FALSE)
        )

        # 3. DR Estimator
        dr_logit_ols <- extract_estimate(
          "DR-OLS-logit", "doubly robust",
          tpate(formula_outcome = formula_outcome,
                formula_weights = formula_weights,
                exp_data = exp_data, pop_data = pop_subset,
                treatment = "treat_ind", weights_type = "logit",
                est_type = "dr-ols",
                sims = nboot,
                compute_sate = FALSE,
                boot = FALSE,
                numCores = 1)
        )

        dr_logit_bart <- extract_estimate(
          "DR-BART-logit", "doubly robust",
          tpate(formula_outcome = formula_outcome,
                formula_weights = formula_weights,
                exp_data = exp_data, pop_data = pop_subset,
                treatment = "treat_ind", weights_type = "logit",
                est_type = "dr-bart",
                sims = nboot,
                compute_sate = FALSE,
                boot = FALSE)
        )

        data.frame(sim = i, samp_size = samp_size,
                   bind_rows(sate_est, ipw_logit, wls_logit,
                             out_ols, out_bart,
                             dr_logit_ols, dr_logit_bart,
                             PATE, fs_PATE))
      }, mc.cores = n_cores)

      ## mclapply() returns a "try-error" object for any replicate that failed
      ## outright. Keep only the data frames and log how many were dropped.
      ok <- vapply(sims, is.data.frame, logical(1))
      if (!all(ok)) {
        cat(sum(!ok), "of", length(sims), "replicates failed and were dropped\n")
      }

      sims <- bind_rows(sims[ok]) %>%
        mutate(correct_model = correct,
               correct_sample = correct_sample)

      full_sims_nat <- bind_rows(full_sims_nat, sims)
    }
  }
}

## ---- Save results and runtime -----------------------------------------------
naturalistic_sims_nonlinear_samp_correct <- full_sims_nat

save(naturalistic_sims_nonlinear_samp_correct,
     file = "./generated/appendix_2/naturalistic_sims_nonlinear_samp_correct.RData")
sink()  # close the run log opened at the top of the script

## Record total wall-clock runtime in a separate log. print() is explicit
## because top-level values are not auto-printed when the script is run via
## source() with its default echo = FALSE, which would leave this file empty.
sink("./generated/appendix_2/natural_sims_time_07.txt")
end_time <- Sys.time()
print(end_time - start_time)
sink()